# -----------------------------------------
# Run Heterogeneity permutation test
# -----------------------------------------

# Uses the original BART implementation (package BayesTree) rather than the
# modified version of BART that runs under R versions later than 3.x.
# This script was run under R version 3.0.3.

# Start from an empty workspace. all.names = TRUE also removes hidden objects
# (names starting with "."), e.g. a previously stored .Random.seed.
rm(list = ls(all.names = TRUE))

library(BayesTree)

# Folder holding the input .RData files; every file below is read/written
# relative to the working directory. Leave empty to use the current working
# directory (setwd("") itself fails, so it is only called for a real path).
work.dir <- ""
if (nzchar(work.dir)) setwd(work.dir)

# Fit BART to the observed data and save per-unit CATE posterior draws.
#
# Arguments:
#   case : which data set to use (1-4):
#            1 -> reads "id1USc_female.RData",  writes "thetid1_US.RData"
#            2 -> reads "id1GBc_female.RData",  writes "thetid1_GB.RData"
#            3 -> reads "id2<cntry.list[cc]>c_female.RData",
#                 writes "thetid2_<cntry.list[cc]>.RData"
#            4 -> reads "id2ESSc_female.RData", writes "thetid2_ESS.RData"
#   cc   : index into the global `cntry.list`; only used when case == 3.
#
# The input file must provide `X` and `y`. The saved matrix has one row per
# observation and one column per posterior draw, holding
# yhat(Tr.0 = 1) - yhat(Tr.0 = 0). It is saved under the name `res` and,
# identically, under `tau` (the name the Figure 4 code reads back).
# Side effect: calls set.seed(123456788) before fitting.
# Returns the CATE draw matrix invisibly.
run.hetero <- function(case, cc) {
  covar <- "c_female"

  # case -> c(prefix, label). switch() evaluates only the selected branch, so
  # cntry.list[cc] is looked up only for case 3.
  src <- switch(case,
                c("id1", "US"),
                c("id1", "GB"),
                c("id2", cntry.list[cc]),
                c("id2", "ESS"))
  if (length(src) != 2 || anyNA(src)) {
    stop("`case` must be 1-4 and, for case 3, `cc` must index cntry.list",
         call. = FALSE)
  }
  in.file  <- paste0(src[1], src[2], covar, ".RData")
  out.file <- paste0("thet", src[1], "_", src[2], ".RData")

  loaded <- load(file = in.file, envir = environment())
  if (!all(c("X", "y") %in% loaded)) {
    stop(in.file, " must contain objects `X` and `y`", call. = FALSE)
  }

  temp.X <- makeind(X)
  y <- as.numeric(y)
  n <- nrow(X)

  # Counterfactual test set: rows 1..n with Tr.0 = 0, rows n+1..2n with
  # Tr.0 = 1, all other columns as observed.
  # NOTE(review): makeind() expands factors into indicator columns; if the
  # treatment is a factor, check whether other level columns (e.g. Tr.1) also
  # exist and whether Tr.0 = 1 means treated or control -- this fixes the
  # sign of tau.
  S0 <- temp.X
  S1 <- temp.X
  S0[, "Tr.0"] <- 0
  S1[, "Tr.0"] <- 1
  temp.S <- rbind(S0, S1)

  set.seed(123456788)
  out.bart <- bart(x.train = temp.X, y.train = y, ndpost = 1000, nskip = 100,
                   ntree = 200, x.test = temp.S)

  # yhat.test has one row per posterior draw and one column per test row;
  # transpose so rows are observations and columns are draws.
  tau <- t(out.bart$yhat.test[, n + seq_len(n), drop = FALSE] -
           out.bart$yhat.test[, seq_len(n), drop = FALSE])

  res <- tau
  save(res, tau, file = out.file)
  invisible(tau)
}


# Permutation tests: refit BART after randomly permuting the covariate rows.
# Permutation reference fits for the heterogeneity test: refit BART after
# randomly permuting the covariate rows while keeping the Tr.0 column (and y)
# in their original order, then save per-unit CATE posterior draws.
#
# Arguments:
#   case      : which data set to use (1-4):
#                 1 -> reads "id1USc_female.RData",
#                      writes "perm_<i>_thetid1_US.RData"
#                 2 -> reads "id1GBc_female.RData",
#                      writes "perm_<i>_thetid1_GB.RData"
#                 3 -> reads "id2<cntry.list[cc]>c_female.RData",
#                      writes "perm_<i>_thetid2_<cntry.list[cc]>.RData"
#                 4 -> reads "id2ESSc_female.RData",
#                      writes "perm_<i>_thetid2_ESS.RData"
#   cc        : index into the global `cntry.list`; only used when case == 3.
#   n.perm    : number of permutations, i = 1..n.perm (default 10).
#   perm.seed : seed used to draw the permutations.
#
# Each output file holds `tau`: one row per observation, one column per
# posterior draw. Side effect: resets the global RNG via set.seed().
# Returns NULL invisibly.
permute.hetero <- function(case, cc, n.perm = 10, perm.seed = 123456788) {
  covar <- "c_female"

  # case -> c(prefix, label); cntry.list[cc] is evaluated only for case 3.
  src <- switch(case,
                c("id1", "US"),
                c("id1", "GB"),
                c("id2", cntry.list[cc]),
                c("id2", "ESS"))
  if (length(src) != 2 || anyNA(src)) {
    stop("`case` must be 1-4 and, for case 3, `cc` must index cntry.list",
         call. = FALSE)
  }
  in.file  <- paste0(src[1], src[2], covar, ".RData")
  out.file <- paste0("thet", src[1], "_", src[2], ".RData")

  loaded <- load(file = in.file, envir = environment())
  if (!all(c("X", "y") %in% loaded)) {
    stop(in.file, " must contain objects `X` and `y`", call. = FALSE)
  }

  temp.X <- makeind(X)
  y <- as.numeric(y)
  n <- nrow(temp.X)
  tr <- temp.X[, "Tr.0"]

  # Draw all permutations up front from their own seed. Sampling inside the
  # loop (before set.seed() and after the previous bart() call) would make
  # the first permutation depend on unseeded RNG state and later ones on how
  # many random numbers bart() happened to consume -- i.e. not reproducible.
  set.seed(perm.seed)
  perms <- replicate(n.perm, sample.int(n), simplify = FALSE)

  for (i in seq_len(n.perm)) {
    # NOTE(review): only Tr.0 is held fixed; any other treatment-level
    # indicator produced by makeind() (e.g. Tr.1) is permuted with the
    # covariates -- confirm the treatment coding.
    X.perm <- temp.X[perms[[i]], , drop = FALSE]
    X.perm[, "Tr.0"] <- tr

    S0 <- X.perm
    S1 <- X.perm
    S0[, "Tr.0"] <- 0
    S1[, "Tr.0"] <- 1
    S.perm <- rbind(S0, S1)

    set.seed(123456788)
    out.bart <- bart(x.train = X.perm, y.train = y, ndpost = 1000, nskip = 100,
                     ntree = 200, x.test = S.perm)

    # Rows = observations, columns = posterior draws.
    tau <- t(out.bart$yhat.test[, n + seq_len(n), drop = FALSE] -
             out.bart$yhat.test[, seq_len(n), drop = FALSE])
    save(tau, file = paste0("perm_", i, "_", out.file))
  }
  invisible(NULL)
}



# Country labels for case 3 (`cc` indexes this vector); run.hetero() and
# permute.hetero() read it as a global.
cntry.list <- c("US", "GB")

# Same (case, cc) pairs for both stages -- (1, 0), (2, 0), (3, 1), (4, 0) --
# with every permutation run finishing before the observed-data fits start.
invisible(mapply(permute.hetero, case = c(1, 2, 3, 4), cc = c(0, 0, 1, 0),
                 SIMPLIFY = FALSE))
invisible(mapply(run.hetero, case = c(1, 2, 3, 4), cc = c(0, 0, 1, 0),
                 SIMPLIFY = FALSE))


# ===================================================================
# Figure 4
# ===================================================================
# Return the first object stored in an .RData file, whatever its name.
# The object names differ between files (run.hetero() saves `res`,
# permute.hetero() saves `tau`), so loading into a private environment avoids
# depending on a particular name appearing in the global workspace.
read.rdata <- function(file) {
  env <- new.env()
  nms <- load(file, envir = env)
  if (length(nms) == 0) stop(file, " contains no objects", call. = FALSE)
  get(nms[1], envir = env)
}

# For one data set `tag` (file stem after "thet"), return
#   Y  : per-unit posterior-mean CATE from the observed fit, divided by scale;
#   XT : matrix with one column per permutation file i = 1..n.perm, holding
#        the per-unit posterior-mean CATE, divided by scale.
# The draw matrices have one row per unit, so rowMeans averages over draws.
collect.hetero <- function(tag, scale, n.perm) {
  obs <- rowMeans(read.rdata(paste0("thet", tag, ".RData")) / scale)
  perm <- vapply(seq_len(n.perm), function(i) {
    rowMeans(read.rdata(paste0("perm_", i, "_thet", tag, ".RData")) / scale)
  }, numeric(length(obs)))
  list(Y = obs, XT = matrix(perm, nrow = length(obs)))
}

n.perm <- 10

# Panel order: US party ID, US ideology, UK party ID, Europe ideology
# (id1 = Party ID, id2 = Political Ideology, per the panel titles).
# NOTE(review): the divisors 3 (party ID) and 100 (ideology) presumably
# rescale effects by the outcome scale -- confirm.
fig4.tags   <- c("id1_US", "id2_US", "id1_GB", "id2_ESS")
fig4.scales <- c(3, 100, 3, 100)

Y <- list()
XT <- list()
for (k in seq_along(fig4.tags)) {
  h <- collect.hetero(fig4.tags[k], fig4.scales[k], n.perm)
  Y[[k]]  <- h$Y
  XT[[k]] <- h$XT
}


# Plot the density of observed per-unit CATE estimates against the densities
# of the permutation-based estimates.
#
# Arguments:
#   Y       : numeric vector of CATE estimates from the observed-data fit
#             (drawn in black, lwd 2).
#   XT      : numeric matrix with one column per permutation (drawn in grey);
#             any number of columns is accepted, a vector counts as one column.
#   m.title : main title of the panel.
#
# Axis limits cover every density curve. Draws on the current graphics device
# and returns NULL invisibly.
graph.hetero <- function(Y, XT, m.title) {
  XT <- as.matrix(XT)

  # Compute each density once and reuse it for the limits and the drawing.
  dens.obs  <- density(Y)
  dens.perm <- lapply(seq_len(ncol(XT)), function(j) density(XT[, j]))
  dens.all  <- c(list(dens.obs), dens.perm)

  ymax <- max(vapply(dens.all, function(d) max(d$y), numeric(1)))
  xlim <- range(unlist(lapply(dens.all, function(d) d$x)))

  plot(dens.obs, lwd = 2, sub = "", main = m.title,
       xlab = expression(widehat("CATE")), ylim = c(0, ymax), xlim = xlim)
  for (d in dens.perm) lines(d, lwd = 1.5, col = "grey80")
  # Redraw the observed density so the grey permutation curves do not hide it.
  lines(dens.obs, lwd = 2)

  invisible(NULL)
}


# Figure 4: a 2 x 2 grid filled column by column (US panels on the left,
# UK/Europe on the right), written to a single PNG.
png("hetero_effect.png", height = 2500, width = 3000, res = 250)
layout(matrix(1:4, nrow = 2, ncol = 2, byrow = FALSE))
invisible(Map(graph.hetero, Y, XT, c(
  "US : Heterogeneous Treatment Effects on Party ID",
  "US : Heterogeneous Treatment Effects on Political Ideology",
  "UK : Heterogeneous Treatment Effects on Party ID",
  "Europe : Heterogeneous Treatment Effects on Political Ideology"
)))
dev.off()

